import pandas as pd
import matplotlib.pyplot as plt

# Load the purchase records; the path is resolved relative to the current
# working directory.
frame = pd.read_csv(r'./data/BlackFriday.csv')

# Mean purchase amount per age group. Column labels are case-sensitive and the
# rest of this script addresses the column as 'Age'; the lower-case 'age'
# lookup raised KeyError.
results = frame[['Purchase']].groupby(frame['Age']).mean()

# x = mean purchase, y = age group. Marker colour and marker area both encode
# the mean purchase, and the colorbar gives the colour scale.
plt.scatter(results['Purchase'], results.index, c=results['Purchase'], s=results['Purchase'])
plt.colorbar()
plt.show()

# Average purchase amount for each distinct Stay_In_Current_City_Years value.
results = frame.groupby('Stay_In_Current_City_Years')[['Purchase']].mean()

# The mean purchase drives the x position, the colour and the marker area;
# the group label is plotted on the y axis.
mean_purchase = results['Purchase']
plt.scatter(
    x=mean_purchase,
    y=results.index,
    s=mean_purchase,
    c=mean_purchase,
)
plt.colorbar()
plt.show()

from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
# Mean purchase for every (years-in-city, age-group) pair. The aggregation is
# required: a bare GroupBy object has no `.index`, so the un-aggregated
# version failed as soon as `results.index.codes` was accessed.
results = frame[['Purchase']].groupby(
    [frame['Stay_In_Current_City_Years'], frame['Age']]).mean()
mean_purchase = results['Purchase'].to_numpy()

# A single mappable is shared by the markers and the colorbar so that both
# use exactly the same value-to-colour scale.
sm = plt.cm.ScalarMappable(
    norm=plt.Normalize(vmin=mean_purchase.min(), vmax=mean_purchase.max()))
sm.set_array([])  # some Matplotlib releases require an array before a colorbar is drawn

ax = plt.subplot(projection='3d')
# x = mean purchase; y and z are the integer codes of the two grouping levels.
ax.scatter(mean_purchase, results.index.codes[0], results.index.codes[1],
           color=sm.to_rgba(mean_purchase), s=200)
# Axes.scatter does not register a "current image" with pyplot, so the
# mappable and target axes must be passed to colorbar explicitly.
plt.colorbar(sm, ax=ax)
plt.show()

# Show every column and allow wide rows so the correlation matrix is not
# truncated or wrapped when printed.
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 500)

# At this point the frame still holds text columns (e.g. 'Age' and
# 'Stay_In_Current_City_Years' are strings). pandas >= 2.0 no longer drops
# them silently in corr() and raises instead, so restrict the input to
# numeric/bool columns explicitly; older pandas produces the same matrix.
print(frame.select_dtypes(include=['number', 'bool']).corr())

# Encode each age label by its first character, computed once and reused for
# both the preview print and the column update. The character must be a digit
# for the int conversion below to succeed.
# NOTE(review): labels that share a leading digit collapse into the same
# code -- confirm that this coarse encoding is acceptable.
age_code = frame['Age'].str[0]
print(age_code)
frame['Age'] = age_code.astype(int)

# Turn the open-ended '4+' label into '4' before converting the whole column.
# Replacing text with text avoids writing an int into a string column, which
# yields mixed types and is rejected by strict string dtypes.
frame['Stay_In_Current_City_Years'] = (
    frame['Stay_In_Current_City_Years'].replace('4+', '4').astype(int))

# Correlate numeric/bool columns only: any remaining text column would make
# corr() raise on pandas >= 2.0.
print(frame.select_dtypes(include=['number', 'bool']).corr())


